{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#include "solver.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <utility>

#include "caspar_mappings.h"
#include "shared_indices.h"
#include "solver_tools.h"
#include "sort_indices.h"

{% for fac in solver.factors%}
#include "kernel_{{fac.name}}_res_jac.h"
#include "kernel_{{fac.name}}_res_jac_first.h"
#include "kernel_{{fac.name}}_score.h"
{% for arg, typ in fac.node_arg_types.items()%}
#include "kernel_{{fac.name}}_{{arg}}_jnjtr.h"
#include "kernel_{{fac.name}}_{{arg}}_jtjnjtr.h"
#include "kernel_{{fac.name}}_{{arg}}_njtr_precond.h"
{% endfor %}
{% endfor %}
{% for nodetype in solver.node_types%}
#include "kernel_{{nodetype.__name__}}_alpha_denumerator_or_beta_nummerator.h"
#include "kernel_{{nodetype.__name__}}_alpha_numerator_denominator.h"
#include "kernel_{{nodetype.__name__}}_normalize.h"
#include "kernel_{{nodetype.__name__}}_pred_decrease.h"
#include "kernel_{{nodetype.__name__}}_retract.h"
#include "kernel_{{nodetype.__name__}}_start_w.h"
#include "kernel_{{nodetype.__name__}}_start_w_contribute.h"
#include "kernel_{{nodetype.__name__}}_update_r.h"
#include "kernel_{{nodetype.__name__}}_update_r_first.h"
#include "kernel_{{nodetype.__name__}}_update_step_first.h"
#include "kernel_{{nodetype.__name__}}_update_step_or_update_p.h"
{% endfor %}

namespace {

// Rounds `offset` up, in place, to the nearest multiple of `alignment`.
// An offset that is already a multiple is left untouched. `alignment` must be non-zero.
void make_aligned(size_t& offset, size_t alignment = 16) {
  const size_t num_chunks = (offset + alignment - 1) / alignment;  // ceil(offset / alignment)
  offset = num_chunks * alignment;
}

// Bookkeeping only: aligns `offset`, then advances it past `nwords` elements of type T.
// Mirrors assign_and_increment without producing a pointer.
template <typename T>
void increment_offset(size_t& offset, size_t nwords, size_t alignment) {
  make_aligned(offset, alignment);
  const size_t region_nbytes = nwords * sizeof(T);
  offset += region_nbytes;
}

// Aligns `offset`, returns a T* to that position inside the buffer starting at `origin_ptr`,
// and advances `offset` past `nwords` elements of type T.
template <typename T>
T* assign_and_increment(uint8_t* origin_ptr, size_t& offset, size_t nwords, size_t alignment) {
  make_aligned(offset, alignment);
  T* const region = reinterpret_cast<T*>(origin_ptr + offset);
  offset += nwords * sizeof(T);
  return region;
}

}  // namespace

namespace caspar {

// Constructs the solver and carves all of its device buffers out of one cudaMalloc'd block.
//
// params: solver parameters, copied into params_.
// One size_t count per size contributor follows. Each count initializes the current number,
// the number of 1024-element blocks (rounded up), and the maximum number for that contributor.
//
// Throws std::runtime_error if the device allocation fails. Indices start out invalid, so
// solve() cannot run until they have been set and finish_indices() has been called.
{{ solver.struct_name }}::{{ solver.struct_name }}(
    const SolverParams &params,
    {% for thing in solver.size_contributors %}
    size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
    {% endfor %}
    )
    : params_(params),
      {% for thing in solver.size_contributors %}
      {{num_key(thing)}}({{num_arg_key(thing)}}),
      {{num_blocks_key(thing)}}(({{num_arg_key(thing)}} + 1023) / 1024),
      {{num_max_key(thing)}}({{num_arg_key(thing)}}){{ ", " if not loop.last else "" }}
      {% endfor %}
      {
  indices_valid_ = false;

  allocation_size_ = get_nbytes();
  // Every field pointer below is derived from origin_ptr_, so a failed allocation must not be
  // silently ignored.
  const cudaError_t alloc_status = cudaMalloc(&origin_ptr_, allocation_size_);
  if (alloc_status != cudaSuccess) {
    throw std::runtime_error("cudaMalloc of " + std::to_string(allocation_size_) +
                             " bytes failed: " + cudaGetErrorString(alloc_status));
  }

  // Walk the fields in the same order and with the same alignment as get_nbytes(), so the
  // final offset equals allocation_size_.
  size_t offset = 0;
  {% for name, field in solver.fields.items() %}
  {{name}} = assign_and_increment<{{field.dtype}}>(
      origin_ptr_, offset, {{field.dim_real}} * {{field.num_key}}, {{field.alignment}});
  {% endfor %}

  // Sizes are pointer differences, i.e. counted in elements of the marker pointer type.
  // In/out staging scratch: from marker__scratch_inout_ to the end of the allocation.
  scratch_inout_size_ = marker__end_ - marker__scratch_inout_;
  // Scratch range between marker__scratch_sum_ and marker__scratch_sum_end_.
  scratch_sum_size_ = marker__scratch_sum_end_ - marker__scratch_sum_;
}

// Releases the single device allocation made by the constructor.
// The status returned by cudaFree is discarded.
{{ solver.struct_name }}::~{{ solver.struct_name }}() {
  (void)cudaFree(origin_ptr_);
}

void {{ solver.struct_name }}::set_params(const SolverParams &params){
  this->params_ = params;
}

// Returns the size in bytes of the device allocation made by the constructor.
size_t {{ solver.struct_name }}::get_allocation_size() {
  const size_t nbytes = allocation_size_;
  return nbytes;
}


float GraphSolver::solve(bool print_progress) {
  float score_best;
  float score_best_pcg;
  float diag = params_.diag_init;
  cudaMemcpy(solver__current_diag_, &diag, sizeof(float), cudaMemcpyHostToDevice);

  if (!indices_valid_) {
    throw std::runtime_error(
        "Indices not valid. Set indices and call finish_indices before running solve().");
  }

  float up_scale = params_.diag_scaling_up;
  float quality;

  std::chrono::time_point<std::chrono::steady_clock> t0 = std::chrono::steady_clock::now();
  std::chrono::time_point<std::chrono::steady_clock> t_prev = t0;
  score_best = do_res_jac_first();
  if (print_progress) {
    printf("                                 score_init: % .6e\n", score_best);
  }

  for (solver_iter_ = 0; solver_iter_ < params_.solver_iter_max; solver_iter_++) {
    if (solver_iter_ != 0 && solver_iter_ < params_.solver_iter_max - 1){
      do_res_jac();
    }
    do_njtr_precond();
    score_best_pcg = score_best;
    for (pcg_iter_ = 0; pcg_iter_ < params_.pcg_iter_max; pcg_iter_++) {
      do_normalize();

      if (pcg_iter_ == 0) {
        do_jp();
        do_jtjp();
        do_alpha_first();
        do_update_step_first();
        do_update_r_first();
      } else {
        do_beta();
        do_update_p();
        do_jp();
        do_jtjp();
        do_alpha();
        do_update_step();
      }
      float score_new_pcg = do_retract_score();
      if (score_new_pcg > score_best_pcg * params_.pcg_rel_decrease_min) {
        pcg_iter_--;  // final step was not accepted
        break;
      }
      if (pcg_iter_ != 0){
        do_update_r();
      }
      score_best_pcg = score_new_pcg;
      {% for node_type in solver.node_types %}
      std::swap({{node_key(node_type, "storage_check")}},
                {{node_key(node_type, "storage_new_best")}});
      {% endfor %}

      if (score_best_pcg < score_best * params_.pcg_rel_score_exit){
        break;
      }
      if (pcg_r_kp1_norm2_ < pcg_r_0_norm2_ * params_.pcg_rel_error_exit){
        break;
      }
    }
    pcg_iter_ = std::min(pcg_iter_, params_.pcg_iter_max - 1);

    if (score_best_pcg < score_best * params_.solver_rel_decrease_min) {
      quality = (score_best - score_best_pcg) / get_pred_decrease();
      const float quality_tmp = 2 * quality - 1;
      float scale =
          std::max(params_.diag_scaling_down, 1.0f - quality_tmp * quality_tmp * quality_tmp);
      diag = std::max(params_.diag_min, diag * scale);
      cudaMemcpy(solver__current_diag_, &diag, sizeof(float), cudaMemcpyHostToDevice);
      up_scale = params_.diag_scaling_up;
      score_best = score_best_pcg;
      {% for node_type in solver.node_types %}
      std::swap({{node_key(node_type, "storage_current")}},
                {{node_key(node_type, "storage_new_best")}});
      {% endfor %}

    } else {
      quality = 0.0f;
      diag = diag * up_scale;
      if (diag > params_.diag_exit_value) {
        break;
      }
      cudaMemcpy(solver__current_diag_, &diag, sizeof(float), cudaMemcpyHostToDevice);
      up_scale *= 2;
    }
    const auto t_now = std::chrono::steady_clock::now();
    if (print_progress) {
      printf("solver_iter: % 3d  ", solver_iter_);
      printf("pcg_iter: % 3d  ", pcg_iter_);
      printf("score_best: % .6e  ", score_best);
      printf("step_quality: % 6.3f  ", quality);
      printf("diag: % 6.3e  ", diag);
      printf("dt_inc: % 10.6f  ", (float)(t_now - t_prev).count() * 1e-9);
      printf("dt_tot: % 10.6f  ", (float)(t_now - t0).count() * 1e-9);
      t_prev = t_now;
      printf("\n");
    }
    if (score_best <= params_.score_exit_value) {
      break;
    }
  }

  return score_best;
}

// First residual/Jacobian pass of a solve; returns the initial score.
//
// Clears the res_tot range, calls every factor's *_res_jac_first kernel with the node values
// in storage_current, and returns half of the sum over the res_tot range.
// Each buffer argument is followed by a count (num_max or num_blocks) - presumably the
// buffer's stride/capacity; confirm against the generated kernel headers.
float {{ solver.struct_name }}::do_res_jac_first() {
  zero(marker__res_tot_start_, marker__res_tot_end_);
  {% for fac in solver.factors %}
  {{fac.name}}_res_jac_first(
      {% for arg, typ in fac.arg_types.items() %}
      {% if fac.isnode[arg] %}
          {{node_key(typ, "storage_current")}},
          {{num_max_key(typ)}},
          {{arg_key(fac, arg, "idx_shared")}},
      {% else %}
          {{arg_key(fac, arg, "data")}},
          {{num_max_key(fac)}},
      {% endif%}
      {% endfor %}
      {{fac_key(fac, "res")}},
      {{num_max_key(fac)}},
      {{fac_key(fac, "res_tot")}},
      {{num_blocks_key(fac)}},
      {% for arg, typ in fac.node_arg_types.items()%}
          {{arg_key(fac, arg, "jac")}},
          {{num_max_key(fac)}},
          {{arg_key(fac, arg, "idx_target")}},
      {% endfor %}
      {{num_key(fac)}}
  );
  {% endfor %}
  // marker__scratch_sum_ and the element after it are handed to sum() as working storage.
  return 0.5 * sum(marker__res_tot_start_, marker__res_tot_end_, marker__scratch_sum_,
                   marker__scratch_sum_ + 1);
}

// Recomputes residuals and Jacobians for every factor at storage_current.
//
// Same dispatch as do_res_jac_first, but calls the *_res_jac kernels, which take no res_tot
// output, and returns nothing.
void {{ solver.struct_name }}::do_res_jac() {
    {% for fac in solver.factors %}
    {{fac.name}}_res_jac(
        {% for arg, typ in fac.arg_types.items() %}
        {% if fac.isnode[arg] %}
            {{node_key(typ, "storage_current")}},
            {{num_max_key(typ)}},
            {{arg_key(fac, arg, "idx_shared")}},
        {% else %}
            {{arg_key(fac, arg, "data")}},
            {{num_max_key(fac)}},
        {% endif%}
        {% endfor %}
        {{fac_key(fac, "res")}},
        {{num_max_key(fac)}},
        {% for arg, typ in fac.node_arg_types.items()%}
            {{arg_key(fac, arg, "jac")}},
            {{num_max_key(fac)}},
            {{arg_key(fac, arg, "idx_target")}},
        {% endfor %}
        {{num_key(fac)}}
    );
    {% endfor %}
}

// Builds r_0 and the preconditioner buffers from the current residuals and Jacobians.
//
// Clears the range from marker__r_0_start_ to marker__precond_end_, then calls one
// *_njtr_precond kernel per (factor, node argument) pair with the factor's res and jac
// buffers and the node type's r_0, precond_diag and precond_tril buffers. Finally the r_0
// range is copied to the location starting at marker__r_k_a_start_.
void {{ solver.struct_name }}::do_njtr_precond() {
    zero(marker__r_0_start_, marker__precond_end_);
    {% for fac in solver.factors %}
    {% for arg, typ in fac.node_arg_types.items() %}
    {{fac.name}}_{{arg}}_njtr_precond(
        {{fac_key(fac, "res")}}, {{num_max_key(fac)}}, {{arg_key(fac, arg, "idx_argsort")}},
        {{arg_key(fac, arg, "jac")}}, {{num_max_key(fac)}},
        {{node_key(typ, "r_0")}}, {{num_max_key(typ)}}, {{arg_key(fac, arg, "idx_sorted_shared")}},
        {{node_key(typ, "precond_diag")}}, {{num_max_key(typ)}},
        {{node_key(typ, "precond_tril")}}, {{num_max_key(typ)}},
        {{num_key(fac)}}
    );
    {% endfor %}
    {% endfor %}
    // NOTE(review): assumes copy(first, last, dst) and that marker__precond_start_ marks the
    // end of the r_0 region (no separate r_0 end marker is used here) - confirm against the
    // field layout.
    copy(marker__r_0_start_, marker__precond_start_, marker__r_k_a_start_);
}

void {{ solver.struct_name }}::do_normalize() {
    float * r_k;
    float * z;
    {% for nodetype in solver.node_types%}
    r_k = pcg_iter_ % 2 ? {{node_key(nodetype, "r_k_b")}} : {{node_key(nodetype, "r_k_a")}};
    z = pcg_iter_ == 0 ? {{node_key(nodetype, "p_k_b")}} : {{node_key(nodetype, "z")}};
    {{nodetype.__name__}}_normalize(
        {{node_key(nodetype, "precond_diag")}}, {{num_max_key(nodetype)}},
        {{node_key(nodetype, "precond_tril")}}, {{num_max_key(nodetype)}},
        r_k, {{num_max_key(nodetype)}},
        solver__current_diag_,
        z, {{num_max_key(nodetype)}},
        {{num_key(nodetype)}});
    {% endfor%}
}

void {{ solver.struct_name }}::do_jp() {
    zero(marker__jp_start_, marker__jp_end_);
    float* p_kp1;
    {% for fac in solver.factors %}
    {% for arg, typ in fac.jnjtr_args.items() %}
    p_kp1 = pcg_iter_ % 2 ? {{node_key(typ, "p_k_a")}} : {{node_key(typ, "p_k_b")}};
    {{fac.name}}_{{arg}}_jnjtr(
        {{arg_key(fac, arg, "jac")}}, {{num_max_key(fac)}},
        p_kp1, {{num_max_key(typ)}}, {{arg_key(fac, arg, "idx_sorted_shared")}},
        {{fac_key(fac, "jp")}},
        {{num_max_key(fac)}},
        {% if arg != fac.smallest_node %}
            {{arg_key(fac, arg, "idx_jp_target")}},
        {% endif %}
        {{num_key(fac)}}
    );
    {% endfor %}
    {% endfor %}
}

// Fills the w buffer of every node type in two stages.
//
// Stage 1: each node type's *_start_w kernel writes w from precond_diag, the current diag
// and the p buffer selected by the parity of pcg_iter_.
// Stage 2: for every factor with a non-zero count, each *_jtjnjtr kernel receives the
// factor's jac and jp buffers together with the w buffer of the argument's node type.
void {{ solver.struct_name }}::do_jtjp() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;

    {% for nodetype in solver.node_types %}
    {
        float* const direction =
            odd_iter ? {{node_key(nodetype, "p_k_a")}} : {{node_key(nodetype, "p_k_b")}};
        {{nodetype.__name__}}_start_w(
            {{node_key(nodetype, "precond_diag")}},
            {{num_max_key(nodetype)}},
            solver__current_diag_,
            direction,
            {{num_max_key(nodetype)}},
            {{node_key(nodetype, "w")}},
            {{num_max_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}

    {% for fac in solver.factors %}
    // Factors without any instances are skipped entirely.
    if ({{num_key(fac)}} != 0) {
        {% for arg, typ in fac.jnjtr_args.items() %}
        {{fac.name}}_{{arg}}_jtjnjtr(
            {{arg_key(fac, arg, "jac")}},
            {{num_max_key(fac)}},
            {{fac_key(fac, "jp")}},
            {{num_max_key(fac)}},
            {% if arg!=fac.smallest_node %}
                {{arg_key(fac, arg, "idx_jp_target")}},
            {% endif %}
            {{node_key(typ, "w")}},
            {{num_max_key(typ)}},
            {{arg_key(fac, arg, "idx_sorted_shared")}},
            {{num_key(fac)}}
        );
        {% endfor %}
    }
    {% endfor %}
}

// Computes alpha (and its negation) for the first PCG iteration.
//
// Clears the alpha numerator and denominator partial-sum ranges, lets every node type's
// *_alpha_numerator_denominator kernel fill both from the parity-selected p and r buffers and
// from w, reduces each range with sum() into solver__alpha_numerator_ and
// solver__alpha_denumerator_, and finally calls alpha_from_num_denum.
void {{ solver.struct_name }}::do_alpha_first() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    zero(marker__alpha_numerator_tot_start_, marker__alpha_denumerator_tot_end_);
    {% for nodetype in solver.node_types %}
    {
        float* const residual =
            odd_iter ? {{node_key(nodetype, "r_k_b")}} : {{node_key(nodetype, "r_k_a")}};
        float* const direction =
            odd_iter ? {{node_key(nodetype, "p_k_a")}} : {{node_key(nodetype, "p_k_b")}};
        {{nodetype.__name__}}_alpha_numerator_denominator(
            direction, {{num_max_key(nodetype)}},
            residual, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "w")}}, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "alpha_numerator_tot")}}, {{num_blocks_key(nodetype)}},
            {{node_key(nodetype, "alpha_denumerator_tot")}}, {{num_blocks_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}

    // Reduce the partial sums; the results are written to the solver__ buffers.
    sum(marker__alpha_numerator_tot_start_, marker__alpha_numerator_tot_end_,
        solver__alpha_numerator_, marker__scratch_sum_, false);
    sum(marker__alpha_denumerator_tot_start_, marker__alpha_denumerator_tot_end_,
        solver__alpha_denumerator_, marker__scratch_sum_, false);
    alpha_from_num_denum(solver__alpha_numerator_, solver__alpha_denumerator_, solver__alpha_,
                         solver__neg_alpha_);
}

// Computes alpha (and its negation) for PCG iterations after the first.
//
// Only the denominator is recomputed here, from the parity-selected p buffer and w; the
// numerator passed to alpha_from_num_denum is solver__beta_numerator_.
void {{ solver.struct_name }}::do_alpha() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    zero(marker__alpha_denumerator_tot_start_, marker__alpha_denumerator_tot_end_);
    {% for nodetype in solver.node_types %}
    {
        float* const direction =
            odd_iter ? {{node_key(nodetype, "p_k_a")}} : {{node_key(nodetype, "p_k_b")}};
        {{nodetype.__name__}}_alpha_denumerator_or_beta_nummerator(
            direction, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "w")}}, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "alpha_denumerator_tot")}}, {{num_blocks_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}

    sum(marker__alpha_denumerator_tot_start_, marker__alpha_denumerator_tot_end_,
        solver__alpha_denumerator_, marker__scratch_sum_, false);
    alpha_from_num_denum(solver__beta_numerator_, solver__alpha_denumerator_, solver__alpha_,
                         solver__neg_alpha_);
}

// First-iteration step update: for every node type, calls the *_update_step_first kernel
// with p_k_b and solver__alpha_ as inputs and step_k_b as the output buffer.
void {{ solver.struct_name }}::do_update_step_first() {
    {% for nodetype in solver.node_types %}
    {{nodetype.__name__}}_update_step_first(
        {{node_key(nodetype, "p_k_b")}}, {{num_max_key(nodetype)}},
        solver__alpha_,
        {{node_key(nodetype, "step_k_b")}}, {{num_max_key(nodetype)}},
        {{num_key(nodetype)}}
    );
    {% endfor %}
}

// Step update for PCG iterations after the first.
//
// The step buffers alternate with the parity of pcg_iter_: on odd iterations the kernel reads
// step_k_b and p_k_a and writes step_k_a; on even iterations it reads step_k_a and p_k_b and
// writes step_k_b. The scalar passed to the kernel is solver__alpha_.
void {{ solver.struct_name }}::do_update_step() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    {% for nodetype in solver.node_types %}
    {
        float* const direction =
            odd_iter ? {{node_key(nodetype, "p_k_a")}} : {{node_key(nodetype, "p_k_b")}};
        float* const step_in =
            odd_iter ? {{node_key(nodetype, "step_k_b")}} : {{node_key(nodetype, "step_k_a")}};
        float* const step_out =
            odd_iter ? {{node_key(nodetype, "step_k_a")}} : {{node_key(nodetype, "step_k_b")}};
        {{nodetype.__name__}}_update_step_or_update_p(
            step_in, {{num_max_key(nodetype)}},
            direction, {{num_max_key(nodetype)}},
            solver__alpha_,
            step_out, {{num_max_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}
}

void {{ solver.struct_name }}::do_update_r_first() {
    zero(marker__r_0_norm2_tot_start_, marker__r_kp1_norm2_tot_end_);

    {% for nodetype in solver.node_types %}
    {{nodetype.__name__}}_update_r_first(
        {{node_key(nodetype, "r_k_a")}}, {{num_max_key(nodetype)}},
        {{node_key(nodetype, "w")}}, {{num_max_key(nodetype)}},
        solver__neg_alpha_,
        {{node_key(nodetype, "r_k_b")}}, {{num_max_key(nodetype)}},
        {{node_key(nodetype, "r_0_norm2_tot")}}, {{num_max_key(nodetype)}},
        {{node_key(nodetype, "r_kp1_norm2_tot")}}, {{num_max_key(nodetype)}},
        {{num_key(nodetype)}}
    );

    {% endfor %}
    pcg_r_0_norm2_ = sum(marker__r_0_norm2_tot_start_, marker__r_0_norm2_tot_end_,
                         marker__scratch_sum_, marker__scratch_sum_+1);
    pcg_r_kp1_norm2_ = sum(marker__r_kp1_norm2_tot_start_, marker__r_kp1_norm2_tot_end_,
                           marker__scratch_sum_, marker__scratch_sum_+1);
}

// Residual update for PCG iterations after the first.
//
// The residual buffers alternate with the parity of pcg_iter_: odd iterations read r_k_b and
// write r_k_a, even iterations read r_k_a and write r_k_b. The kernel also fills the
// r_kp1_norm2_tot partial sums, which are reduced into pcg_r_kp1_norm2_.
void {{ solver.struct_name }}::do_update_r() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    zero(marker__r_kp1_norm2_tot_start_, marker__r_kp1_norm2_tot_end_);

    {% for nodetype in solver.node_types %}
    {
        float* const residual_in =
            odd_iter ? {{node_key(nodetype, "r_k_b")}} : {{node_key(nodetype, "r_k_a")}};
        float* const residual_out =
            odd_iter ? {{node_key(nodetype, "r_k_a")}} : {{node_key(nodetype, "r_k_b")}};

        {{nodetype.__name__}}_update_r(
            residual_in, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "w")}}, {{num_max_key(nodetype)}},
            solver__neg_alpha_,
            residual_out, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "r_kp1_norm2_tot")}}, {{num_blocks_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}
    pcg_r_kp1_norm2_ = sum(marker__r_kp1_norm2_tot_start_, marker__r_kp1_norm2_tot_end_,
                           marker__scratch_sum_, marker__scratch_sum_+1);
}

float {{ solver.struct_name }}::do_retract_score() {
    float* step;

    {% for nodetype in solver.node_types %}
    step = pcg_iter_ % 2 ? {{node_key(nodetype, "step_k_a")}} : {{node_key(nodetype, "step_k_b")}};

    {{nodetype.__name__}}_retract(
        {{node_key(nodetype, "storage_current")}}, {{num_max_key(nodetype)}},
        step, {{num_max_key(nodetype)}},
        {{node_key(nodetype, "storage_check")}}, {{num_max_key(nodetype)}},
        {{num_key(nodetype)}}
    );
    {% endfor %}

    zero(marker__res_tot_start_, marker__res_tot_end_);
    {% for fac in solver.factors %}
    {{fac.name}}_score(
        {% for arg, typ in fac.arg_types.items() %}
        {% if fac.isnode[arg] %}
            {{node_key(typ, "storage_check")}},
            {{num_max_key(typ)}},
            {{arg_key(fac, arg, "idx_shared")}},
        {% else %}
            {{arg_key(fac, arg, "data")}},
            {{num_max_key(fac)}},
        {% endif %}
        {% endfor %}
        {{fac_key(fac, "res_tot")}},
        {{num_blocks_key(fac)}},
        {{num_key(fac)}}
    );
    {% endfor %}
    return 0.5 * sum(marker__res_tot_start_, marker__res_tot_end_, marker__scratch_sum_,
                     marker__scratch_sum_+1);
}

// Computes beta for PCG iterations after the first.
//
// Clears the beta numerator partial sums, fills them from the parity-selected r buffer and z,
// reduces them into solver__beta_numerator_, and calls beta_from_num_denum with
// solver__alpha_numerator_ as the second argument.
void {{ solver.struct_name }}::do_beta() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    zero(marker__beta_numerator_tot_start_, marker__beta_numerator_tot_end_);
    {% for nodetype in solver.node_types %}
    {
        float* const residual =
            odd_iter ? {{node_key(nodetype, "r_k_b")}} : {{node_key(nodetype, "r_k_a")}};

        {{nodetype.__name__}}_alpha_denumerator_or_beta_nummerator(
            residual, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "z")}}, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "beta_numerator_tot")}}, {{num_blocks_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}
    sum(marker__beta_numerator_tot_start_, marker__beta_numerator_tot_end_,
        solver__beta_numerator_, marker__scratch_sum_, false);
    beta_from_num_denum(solver__beta_numerator_, solver__alpha_numerator_, solver__beta_);
}

// Updates the p buffer for PCG iterations after the first.
//
// The p buffers alternate with the parity of pcg_iter_: odd iterations read p_k_b and write
// p_k_a, even iterations read p_k_a and write p_k_b. The kernel additionally receives z and
// the scalar solver__beta_.
void {{ solver.struct_name }}::do_update_p() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    {% for nodetype in solver.node_types %}
    {
        float* const direction_in =
            odd_iter ? {{node_key(nodetype, "p_k_b")}} : {{node_key(nodetype, "p_k_a")}};
        float* const direction_out =
            odd_iter ? {{node_key(nodetype, "p_k_a")}} : {{node_key(nodetype, "p_k_b")}};
        {{nodetype.__name__}}_update_step_or_update_p(
            {{node_key(nodetype, "z")}}, {{num_max_key(nodetype)}},
            direction_in, {{num_max_key(nodetype)}},
            solver__beta_,
            direction_out, {{num_max_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}
}

// Returns the summed output of the *_pred_decrease kernels; solve() divides the actual score
// decrease by this value.
//
// Each kernel receives the step and residual buffers last written for the current pcg_iter_
// (the _a buffers on odd iterations, the _b buffers on even ones), together with
// precond_diag, the current diag and r_0.
float {{ solver.struct_name }}::get_pred_decrease() {
    const bool odd_iter = (pcg_iter_ % 2) != 0;
    zero(marker__pred_decrease_tot_start_, marker__pred_decrease_tot_end_);
    {% for nodetype in solver.node_types %}
    {
        float* const latest_step =
            odd_iter ? {{node_key(nodetype, "step_k_a")}} : {{node_key(nodetype, "step_k_b")}};
        float* const latest_residual =
            odd_iter ? {{node_key(nodetype, "r_k_a")}} : {{node_key(nodetype, "r_k_b")}};
        {{nodetype.__name__}}_pred_decrease(
            latest_step, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "precond_diag")}}, {{num_max_key(nodetype)}},
            solver__current_diag_,
            latest_residual, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "r_0")}}, {{num_max_key(nodetype)}},
            {{node_key(nodetype, "pred_decrease_tot")}}, {{num_blocks_key(nodetype)}},
            {{num_key(nodetype)}}
        );
    }
    {% endfor %}
    return sum(marker__pred_decrease_tot_start_, marker__pred_decrease_tot_end_,
               marker__scratch_sum_, marker__scratch_sum_+1);
}

// Finalizes the index buffers and marks them valid so that solve() may run.
//
// For every factor and every node argument other than the factor's smallest_node, calls
// select_index with the smallest_node's idx_target buffer, the argument's idx_argsort buffer
// and the argument's idx_jp_target buffer (presumably the output - confirm against
// select_index). Must be called after all indices have been set; setting indices again
// invalidates them.
void {{ solver.struct_name }}::finish_indices() {
    {% for fac in solver.factors %}
    {% for arg, argtype in fac.node_arg_types.items() %}
    {% if arg != fac.smallest_node %}
    select_index({{arg_key(fac, fac.smallest_node, "idx_target")}},
                 {{arg_key(fac, arg, "idx_argsort")}},
                 {{arg_key(fac, arg, "idx_jp_target")}},
                 {{num_key(fac)}});
    {% endif %}
    {% endfor %}
    {% endfor %}
    indices_valid_ = true;
}

{% for nodetype in solver.node_types %}
// Sets the number of active nodes of this type.
// Throws std::runtime_error if `num` exceeds the maximum fixed at construction.
void {{ solver.struct_name }}::set_{{num_key(nodetype)[:-1]}}(const size_t num) {
    const bool fits = num <= {{num_max_key(nodetype)}};
    if (!fits) {
        throw std::runtime_error(std::to_string(num) + " > {{num_max_key(nodetype)}}");
    }
    {{num_key(nodetype)}} = num;
}

void {{ solver.struct_name }}::set_{{nodetype.__name__}}_nodes_from_stacked_host(
    const float* const data, const size_t offset, const size_t num) {
  if (offset + num > {{num_key(nodetype)}}){
    throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
  }
  cudaMemcpy(marker__scratch_inout_, data, {{Ops.storage_dim(nodetype)}} * num * sizeof(float),
             cudaMemcpyHostToDevice);
  {{nodetype.__name__}}_stacked_to_caspar(marker__scratch_inout_,
                                          {{node_key(nodetype, "storage_current")}},
                                          {{num_max_key(nodetype)}}, offset, num);
}

void {{ solver.struct_name }}::set_{{nodetype.__name__}}_nodes_from_stacked_device(
    const float* const data, const size_t offset, const size_t num) {
  if (offset + num > {{num_key(nodetype)}}){
    throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
  }
  {{nodetype.__name__}}_stacked_to_caspar(
      data, {{node_key(nodetype, "storage_current")}}, {{num_max_key(nodetype)}}, offset, num);
}

void {{solver.struct_name}}::get_{{nodetype.__name__}}_nodes_to_stacked_host(
    float* const data, const size_t offset, const size_t num) {
  if (offset + num > {{num_key(nodetype)}}){
    throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
  }
  {{nodetype.__name__}}_caspar_to_stacked({{node_key(nodetype, "storage_current")}},
                                          marker__scratch_inout_, {{num_max_key(nodetype)}},
                                          offset, num);
  cudaMemcpy(data, marker__scratch_inout_, {{Ops.storage_dim(nodetype)}} * num * sizeof(float),
             cudaMemcpyDeviceToHost);
}

void {{solver.struct_name}}::get_{{nodetype.__name__}}_nodes_to_stacked_device(
    float* const data, const size_t offset, const size_t num) {
  if (offset + num > {{num_key(nodetype)}}){
    throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
  }
  {{nodetype.__name__}}_caspar_to_stacked({{node_key(nodetype, "storage_current")}}, data,
                                          {{num_max_key(nodetype)}}, offset, num);
}

{% endfor%}

{% for fac in solver.factors %}
// Sets the number of active instances of this factor.
// Throws std::runtime_error if `num` exceeds the maximum fixed at construction.
void {{ solver.struct_name }}::set_{{num_key(fac)[:-1]}}(const size_t num) {
  const bool fits = num <= {{num_max_key(fac)}};
  if (!fits){
    throw std::runtime_error(std::to_string(num) + " > {{num_max_key(fac)}}");
  }
  {{num_key(fac)}} = num;
}
{% for arg, argtype in fac.node_arg_types.items() %}

// Sets this factor argument's node indices from a host array.
//
// The indices are staged at the start of the in/out scratch region and then handed to the
// device variant, which performs the sorting and validation.
// Throws std::runtime_error if `num` differs from the factor's current count.
void {{ solver.struct_name }}::set_{{fac.name}}_{{arg}}_indices_from_host(
    const unsigned int* const indices, size_t num) {
  const bool count_matches = (num == {{num_key(fac)}});
  if (!count_matches){
    throw std::runtime_error(
        std::to_string(num)
        + " != {{num_key(fac)}}. Use set_{{fac.name}}_num before setting indices.");
  }
  unsigned int* const staged_indices = (unsigned int*)marker__scratch_inout_;
  cudaMemcpy(staged_indices, indices, num * sizeof(unsigned int), cudaMemcpyHostToDevice);
  set_{{fac.name}}_{{arg}}_indices_from_device(staged_indices, num);
}

void {{ solver.struct_name }}::set_{{fac.name}}_{{arg}}_indices_from_device(
    const unsigned int* const indices, size_t num) {
  indices_valid_ = false;

  if (num != {{num_key(fac)}}){
    throw std::runtime_error(
        std::to_string(num)
        + " != {{num_key(fac)}}. Use set_{{fac.name}}_num before setting indices.");
  }

  size_t tmp_size = sort_indices_get_tmp_nbytes(num);
  if (tmp_size + num > scratch_inout_size_) {
      throw std::runtime_error("Scratch_inout_size too small. tmp_size: " + std::to_string(tmp_size) +
                               ", num: " + std::to_string(num) +
                               ", scratch_inout_size_: " + std::to_string(scratch_inout_size_));
  }
  sort_indices(marker__scratch_inout_ + num, tmp_size, indices,
              {{arg_key(fac, arg, "idx_sorted")}},
              {{arg_key(fac, arg, "idx_target")}},
              {{arg_key(fac, arg, "idx_argsort")}},
              num);
  unsigned int idx_max;
  cudaMemcpy(&idx_max, {{arg_key(fac, arg, "idx_sorted")}} + num - 1, sizeof(unsigned int),
             cudaMemcpyDeviceToHost);
  if (idx_max >= {{num_key(argtype)}}) {
    throw std::runtime_error(std::to_string(idx_max) + " >= {{num_key(argtype)}} ("
                             + std::to_string({{num_key(argtype)}}) + ")");
  }

  shared_indices(indices, {{arg_key(fac, arg, "idx_shared")}}, num);
  shared_indices({{arg_key(fac, arg, "idx_sorted")}},
                  {{arg_key(fac, arg, "idx_sorted_shared")}},
                  num);
}
{% endfor %}
{% for arg, argtype in fac.const_arg_types.items() %}
void {{ solver.struct_name }}::set_{{fac.name}}_{{arg}}_data_from_stacked_host(
    const float* const data, size_t offset, size_t num) {
  if (offset + num > {{num_max_key(fac)}}) {
      throw std::runtime_error(std::to_string(offset + num) + " > {{num_max_key(fac)}}");
  }
  cudaMemcpy(marker__scratch_inout_, data, {{Ops.storage_dim(argtype)}} * num * sizeof(float),
             cudaMemcpyHostToDevice);
  {{argtype.__name__}}_stacked_to_caspar(marker__scratch_inout_, {{arg_key(fac, arg, "data")}},
                                         {{num_max_key(fac)}}, offset, num);
}

void {{ solver.struct_name }}::set_{{fac.name}}_{{arg}}_data_from_stacked_device(
    const float* const data, size_t offset, size_t num) {
  if (offset + num > {{num_max_key(fac)}}){
    throw std::runtime_error(std::to_string(offset + num) + " > {{num_max_key(fac)}}");
  }
  {{argtype.__name__}}_stacked_to_caspar(data, {{arg_key(fac, arg, "data")}}, {{num_max_key(fac)}},
                                         offset, num);
}

{% endfor%}
{% endfor%}

// Returns the number of bytes needed to hold every solver field.
//
// Visits the fields in the same order and with the same per-field alignment as the
// constructor's pointer assignment, but only accumulates the running size.
size_t {{ solver.struct_name }}::get_nbytes() {
    size_t total_nbytes = 0;
    {% for name, field in solver.fields.items()%}
    increment_offset<{{field.dtype}}>(
        total_nbytes, {{field.dim_real}} * {{field.num_key}}, {{field.alignment}});
    {% endfor %}

    return total_nbytes;
}

} // namespace caspar
